
% inputs:
% mov: either the actual raw movie or just the path to the raw movie as a
% char array
% background: either the actual raw movie of the background or a path to
% the file
% metaname: char array of the path to the metadata file
% movtype: either 'ios' or 'gcamp' depending on which it is
% imagefiltering: two element logical array. First element is whether you
% want to do the smoothing and second element is whether you want to do a
% median filtering
% gcampcorrection: whether you want to do the bloodflow correction for the
% gcamp signal
% driftCorrectionTransform: optional input if you've already run the drift
% correction - it should be a character array of the path to the drift
% correction file
% stimtype: 'stationary' or 'moving' depending on whether the stimulus moves
% across the screen or stays in one spot; can also be 'optogenetic', which
% runs the optogenetic-stimulation paradigm

% updates from v4.3
% changed to work with data off of Dalsa camera (no hot pixels, dropped
% frames)
%
% updates as of 6/18/2024:
% set up to work with metadata from optostim experiments, which have
% differences in laser power and not stim position or anything like that
function [avMovie_ch1, avMovie_ch2, vascReference, avHbO, avHbR,uniqueIdentifier,stimResponse_ch1,stimResponse_ch2, ch1_sorted, ch2_sorted,stdMovie_ch1,stdMovie_ch2,ch1_sortedZ,ch2_sortedZ,channelNormalization,baselinePixelValues] = parsePeriodicStimMoviesOPTO_v4_4(mov,background,metaname,movtype,imagefiltering,gcampcorrection,driftCorrectionTransform,wavelengths,windowmask,donormalize,stimtype,backgroundSubtractionType,normalizeSingleTrials,bgmetaname)
% Parse a periodic-stimulus widefield movie into per-stimulus trial stacks,
% background-subtract / normalize / clean them, and return trial-averaged
% movies and stimulus responses.
%
% Additional inputs (the others are described in the header above):
% wavelengths: LED wavelengths passed to the hemoglobin calculation
%   (default [636;530])
% windowmask: logical image of pixels to keep (default: all true); pixels
%   outside it are set to 1 in the sorted movies and 0 in the z-scores
% donormalize: divide the sorted movies by each channel's 99th-percentile
%   maximum; only applied when normalizeSingleTrials is true (default false)
% backgroundSubtractionType: 'avg' subtracts the mean background frame,
%   'frame' subtracts the background frame by frame (default 'avg')
% normalizeSingleTrials: divide each trial by its own baseline mean
%   (default true)
% bgmetaname: metadata file of the background movie, used for 'frame'
%   subtraction (default: first *background*.mat in the current folder)
%
% outputs:
% avMovie_ch1/ch2: cell per unique stimulus, median across repeats
% stdMovie_ch1/ch2: cell per unique stimulus, std across repeats
% ch1_sorted/ch2_sorted: cell per unique stimulus of
%   rows x cols x framesPerStim x repeats arrays
% ch1_sortedZ/ch2_sortedZ: the sorted arrays z-scored against the baseline
% vascReference: mean channel-2 image around the reference frame
% avHbO/avHbR: mean hemoglobin changes per stimulus ('ios' only, else NaN)
% uniqueIdentifier: the unique stimulus descriptors defining the cells
% stimResponse_ch1/ch2: mean of the averaged movie over the response window
% channelNormalization: [ch1max; ch2max] 99th-percentile normalizers
% baselinePixelValues: 2 x nStim cell of per-trial mean baseline images

%% optional inputs
if nargin < 11 || isempty(stimtype)
    stimtype = 'stationary';
end
if nargin < 13 || isempty(normalizeSingleTrials)
    normalizeSingleTrials = true;
end
%avg background subtraction will subtract the average across the whole
%movie; 'frame' will subtract individual frames' background
if nargin < 12 || isempty(backgroundSubtractionType)
    backgroundSubtractionType = 'avg';      %'avg' or 'frame'
end
if nargin < 14 || isempty(bgmetaname)
    bgmetaname = [];
end
switch backgroundSubtractionType
    case 'frame'
        if isempty(bgmetaname)
            fn = dir('*background*.mat');
            if isempty(fn)
                error('parsePeriodicStimMoviesOPTO_v4_4:noBackgroundMetadata', ...
                    'No *background*.mat file found in the current folder; pass bgmetaname explicitly.');
            end
            bgmetaname = fn(1).name;
        end
        % load into a struct so the background's laser power cannot be
        % confused with the experiment's own 'laserpower' variable
        bgmeta = load(bgmetaname, 'laserpower');
        stimIdentifierBG = bgmeta.laserpower;
        [~,~,stimindBG] = unique(stimIdentifierBG);
end
%% some relevant inputs:
% imagefiltering(2) (median filtering) is accepted for compatibility but no
% longer used in this version
imagefiltering = imagefiltering(1);
if nargin < 8 || isempty(wavelengths)
    wavelengths = [636;530];
end
if nargin < 6 || isempty(gcampcorrection)
    gcampcorrection = false;
end
if nargin < 10 || isempty(donormalize)
    donormalize = false;
end
% drift-correction mode: absent/empty -> compute and save a new transform;
% scalar double 0 -> skip registration; otherwise a path to a saved transform
computeRegistration = nargin < 7 || isempty(driftCorrectionTransform);
skipRegistration = nargin >= 7 && isa(driftCorrectionTransform,'double') ...
    && isscalar(driftCorrectionTransform) && driftCorrectionTransform == 0;

%% load in movie and some relevant metadata
if ischar(mov)
    rootname = mov;
    mov = loadTiffStack2(rootname);
end
if nargin < 9 || isempty(windowmask)
    windowmask = true(size(mov,1),size(mov,2));
end

if ischar(background)
    tempbackground = loadTiffStack2(background);
else
    tempbackground = background;
end
clear background
refframe = round(size(mov,3)/4.1);        %reference frame for image stabilization
switch stimtype
    case 'optogenetic'
        load(metaname, 'totalFrameTime','stimONframe','stimOFFframe','numframes','laserpower','numrepeats')
        fps = 1/(totalFrameTime/1000);
        preStimTime = (stimONframe)/fps;
        framesPerStim = numframes;
    case {'stationary','moving'}
        load(metaname, 'blueLEDframes', 'fps', 'framesPerStim', 'greenLEDframes','rectPositions','movindicator','numrepeats','stimparams','preStimTime','totalFrameTime','redLEDframes')
    otherwise
        error('parsePeriodicStimMoviesOPTO_v4_4:badStimtype', ...
            'Unknown stimtype ''%s''; expected ''stationary'', ''moving'' or ''optogenetic''.', stimtype);
end


switch movtype
    case 'gcamp'
        blueLEDframes = 1;
        greenLEDframes = 1;
        disp('HARD CODED NUMBER OF FRAMES FOR EACH LED!');
    case {'ios','ios1ch'}
        %keep loaded values
    otherwise
        error('parsePeriodicStimMoviesOPTO_v4_4:badMovtype', ...
            'Unknown movtype ''%s''; expected ''gcamp'', ''ios'' or ''ios1ch''.', movtype);
end

% split the background into the two interleaved channels (channel 1 is all
% zeros for single-channel IOS)
switch movtype
    case 'ios1ch'
        background{2} = tempbackground(:,:,:);
        background{1} = zeros(size(background{2}));
    otherwise
        background{1} = tempbackground(:,:,1:2:end);
        background{2} = tempbackground(:,:,2:2:end);
end
switch backgroundSubtractionType
    case 'frame'
        %do frame by frame background subtraction; in this case don't need
        %to do anything
    case 'avg'
        %do average background subtraction
        background{1} = mean(background{1},3);
        background{2} = mean(background{2},3);
end
clear tempbackground

%% average relevant frames:
switch movtype
    case 'gcamp'
        framesPerTimepoint = (greenLEDframes + blueLEDframes);
    case 'ios'
        framesPerTimepoint = (greenLEDframes + redLEDframes);
    case 'ios1ch'
        framesPerTimepoint = 1;
end
singleChannelFrames = size(mov,3)/framesPerTimepoint;
ch1 = zeros(size(mov,1),size(mov,2),singleChannelFrames);   % order matters here - this should be the blue LED channel
ch2 = zeros(size(mov,1),size(mov,2),singleChannelFrames);
% this block will allow you to do some frame averaging, though you won't
% necessarily do any:
switch movtype
    case 'gcamp'
        for n = 1:size(ch1,3)
            cur = ((n-1)*(framesPerTimepoint) + 1):((n-1)*(framesPerTimepoint) + blueLEDframes);
            ch1(:,:,n) = mean(mov(:,:,cur),3);
            cur = ((n-1)*(framesPerTimepoint) + 1+blueLEDframes):...
                ((n-1)*(framesPerTimepoint) + blueLEDframes+greenLEDframes);
            ch2(:,:,n) = mean(mov(:,:,cur),3);
        end
    case 'ios'
        for n = 1:size(ch1,3)
            cur = ((n-1)*(framesPerTimepoint) + 1):((n-1)*(framesPerTimepoint) + redLEDframes);
            ch1(:,:,n) = mean(mov(:,:,cur),3);
            cur = ((n-1)*(framesPerTimepoint) + 1+redLEDframes):...
                ((n-1)*(framesPerTimepoint) + redLEDframes+greenLEDframes);
            ch2(:,:,n) = mean(mov(:,:,cur),3);
        end
    case 'ios1ch'
        for n = 1:size(ch2,3)
            cur = ((n-1)*(framesPerTimepoint) + 1):((n)*(framesPerTimepoint));
            ch2(:,:,n) = mean(mov(:,:,cur),3);
        end
end
clear mov       %the raw movie is not needed past this point; free it early


%% do image stabilization:
if computeRegistration
    [registeredCh2,imtrans] = driftCorrectWidefield(ch2,refframe);
    save(strrep(metaname,'.mat','_driftCorrection.mat'),'imtrans','refframe');
    refFrames = registeredCh2(:,:,(refframe-3):(refframe+3));
    clear registeredCh2
elseif skipRegistration     %this will skip image registration
    refFrames = ch2(:,:,(refframe-3):(refframe+3));
else
    % load into a struct so only the transform (and its reference frame, if
    % present) enter the workspace
    dc = load(driftCorrectionTransform);
    imtrans = dc.imtrans;
    if isfield(dc,'refframe')
        refframe = dc.refframe;
    end
    clear dc
    refFrames = ch2(:,:,(refframe-3):(refframe+3));
    for n = 1:size(refFrames,3)
        refFrames(:,:,n) = imwarp(refFrames(:,:,n),imtrans(refframe-4+n),'OutputView',imref2d([size(ch2,1),size(ch2,2)]));
    end
end
% low resolution image of the vasculature from channel 2 for later
% referencing; pixels outside the window get the 5th percentile value
vascReference = mean(refFrames,3);
fillValue = prctile(vascReference(:),5);
vascReference = vascReference.*windowmask;
vascReference(~windowmask) = fillValue;
clear refFrames


%% median filtering for hot pixels was done in previous versions but is
% not needed for the Dalsa camera data
%% parse out which stimuli correspond to which movie
switch stimtype
    case 'optogenetic'
        stimlength = (stimOFFframe-stimONframe)/fps;
        fps_effective = fps/framesPerTimepoint;
    otherwise
        if isfield(stimparams{1},'stimtype')
            stimlength = stimparams{1}.totalStimLength;
        elseif isfield(stimparams{1},'circlesize') && ~isfield(stimparams{1},'duration')
            stimlength = stimparams{1}.totalStimLength;
        else
            stimlength = stimparams{1}.duration*60;
        end
        % totalFrameTime may be absent from older metadata files
        if ~exist('totalFrameTime','Var') || isnan(totalFrameTime)
            fps_effective = fps/framesPerTimepoint;
        else
            fps_effective = 1000/totalFrameTime;
        end
end
framesPerStim_effective = framesPerStim/framesPerTimepoint;
stimstart = round(preStimTime*fps_effective);
baselineframes = 2:(stimstart-1);
% response windows: GCaMP starts 0.2 s after onset; IOS starts 1.5 s after
% onset and extends 0.5 s past the end of the stimulus
stimframesGCaMP = (stimstart + round(.2*fps_effective)):(stimstart+round(stimlength*fps_effective)-1);
stimframesIOS = ceil(stimstart + round(1.5*fps_effective)):(stimstart+round(stimlength*fps_effective)+.5*fps_effective);
stimframesGCaMP = stimframesGCaMP(stimframesGCaMP<(framesPerStim/framesPerTimepoint));
stimframesIOS = stimframesIOS(stimframesIOS<(framesPerStim/framesPerTimepoint));

% build one descriptor row per trial; unique rows define the stimulus bins
switch stimtype
    case 'stationary'
        stimIdentifier = zeros(3,size(rectPositions,2),size(rectPositions,3));
        stimIdentifier(1:2,:,:) = rectPositions;
        for n = 1:size(movindicator,1)
            stimIdentifier(3,n,:) = movindicator(n,:);
        end
        [M,N,P] = size(stimIdentifier);
        trialRows = reshape(stimIdentifier,M*N,P)';
    case 'moving'
        stimIdentifier = zeros(3,size(rectPositions,2),size(rectPositions,3));
        stimIdentifier(1:2,:,:) = rectPositions(:,:,:,1);       %this assumes starting position uniquely determines each trajectory here
        for n = 1:size(movindicator,1)
            stimIdentifier(3,n,:) = movindicator(n,:);
        end
        [M,N,P] = size(stimIdentifier);
        trialRows = reshape(stimIdentifier,M*N,P)';
    case 'optogenetic'
        stimIdentifier = laserpower;
        trialRows = stimIdentifier';
end

[uniqueIdentifier,~,stimind] = unique(trialRows,'rows');
if size(stimIdentifier,3) ~= 1
    uniqueIdentifier = reshape(uniqueIdentifier',M,N,[]);
end

% resort the data into cell arrays. Each element of the cell array is all
% the data from that particular stimulation type.
if strcmp(stimtype,'optogenetic')
    nStim = size(uniqueIdentifier,1);
else
    nStim = size(uniqueIdentifier,3);
end
ch1_sorted = cell(1,nStim);      %sorting the stimuli into their corresponding bins
ch2_sorted = cell(1,nStim);
ch1_sortedZ = cell(1,nStim);     %z-score (based on the baseline frames) of that specific stim
ch2_sortedZ = cell(1,nStim);
originalFrameIdx = cell(1,nStim);    %original frame index so the drift correction can be done at the very end after all the other processing is done
for n = 1:nStim
    ch1_sorted{n} = zeros(size(ch1,1),size(ch1,2),framesPerStim_effective,numrepeats);
    ch2_sorted{n} = zeros(size(ch2,1),size(ch2,2),framesPerStim_effective,numrepeats);
    originalFrameIdx{n} = zeros(framesPerStim_effective,numrepeats);
end

for n = 1:nStim
    curind = find(stimind == n);
    for n2 = 1:length(curind)
        ind = ((curind(n2)-1)*framesPerStim_effective + 1):(curind(n2)*framesPerStim_effective);
        ch1_sorted{n}(:,:,:,n2) = ch1(:,:,ind);
        ch2_sorted{n}(:,:,:,n2) = ch2(:,:,ind);
        originalFrameIdx{n}(:,n2) = ind;
    end
    % subtract out background signal
    switch backgroundSubtractionType
        case 'avg'
            ch1_sorted{n} = ch1_sorted{n} - repmat(background{1},1,1,size(ch1_sorted{n},3),size(ch1_sorted{n},4));
            ch2_sorted{n} = ch2_sorted{n} - repmat(background{2},1,1,size(ch2_sorted{n},3),size(ch2_sorted{n},4));
        case 'frame'
            if strcmp(stimtype,'optogenetic')  %for optogenetic, the background is only at one laser power since you're not acquiring during laser on anyway
                curbg = background{2};
                ch2_sorted{n} = ch2_sorted{n} - repmat(curbg,1,1,1,size(ch2_sorted{n},4));
            else
                curbgind = find(stimindBG == n);
                ind = ((curbgind(1)-1)*framesPerStim_effective + 1):(curbgind(1)*framesPerStim_effective);
                curbg = background{1}(:,:,ind);
                ch1_sorted{n} = ch1_sorted{n} - repmat(curbg,1,1,1,size(ch1_sorted{n},4));

                curbg = background{2}(:,:,ind);
                ch2_sorted{n} = ch2_sorted{n} - repmat(curbg,1,1,1,size(ch2_sorted{n},4));
            end
    end
end
%clear ch1 and ch2 for memory considerations:
channelsizes = [size(ch1,1),size(ch1,2)];
clear ch1 ch2
disp('stimulus parsing done')
%% get ratiometric value to the baseline
baselinePixelValues = cell(2,length(ch2_sorted));
for n = 1:length(ch2_sorted)
    if ~strcmp(movtype,'ios1ch')
        curbaseline = mean(ch1_sorted{n}(:,:,baselineframes,:),3);
        baselinePixelValues{1,n} = curbaseline;
        curbaseline = repmat(curbaseline,1,1,framesPerStim_effective,1);
        if normalizeSingleTrials
            ch1_sorted{n} = ch1_sorted{n}./curbaseline;
        end
    end
    curbaseline = mean(ch2_sorted{n}(:,:,baselineframes,:),3);
    baselinePixelValues{2,n} = curbaseline;
    curbaseline = repmat(curbaseline,1,1,framesPerStim_effective,1);
    if normalizeSingleTrials
        ch2_sorted{n} = ch2_sorted{n}./curbaseline;
    end

    % remove pixels outside of the windowmask:
    temp = repmat(windowmask,1,1,size(ch2_sorted{n},3),size(ch2_sorted{n},4));
    ch1_sorted{n}(~temp) = 1;
    ch2_sorted{n}(~temp) = 1;
end

%% run a round of hot pixel removal to take care of outliers:
for n = 1:length(ch1_sorted)
    for n2 = 1:size(ch1_sorted{n},3)
        for n3 = 1:size(ch1_sorted{n},4)
            temp = ch1_sorted{n}(:,:,n2,n3);
            temp(isnan(temp)) = Inf;
            temp = removeHotPixels(temp);
            ch1_sorted{n}(:,:,n2,n3) = temp;

            temp = ch2_sorted{n}(:,:,n2,n3);
            temp(isnan(temp)) = Inf;
            temp = removeHotPixels(temp);
            temp = removeHotPixels(temp);       %run twice to more effect
            ch2_sorted{n}(:,:,n2,n3) = temp;
        end
    end
end

disp('hot pixel removal done')

%% run correction for fluorescence if applicable:
if gcampcorrection && strcmp(movtype,'gcamp')
    for n = 1:length(ch1_sorted)
        ch1_sorted{n} = ch1_sorted{n}./ch2_sorted{n};
    end
end
%% if doing multispectral IOS, can do the hemoglobin concentration math here:
if strcmp(movtype,'ios')
    load('iosLightPropagationProperties.mat', 'DPFx_cm','HbExtinctionCoeff');
    [deltaHbO, deltaHbR] = gethemoglobin(ch1_sorted,ch2_sorted,wavelengths,DPFx_cm,HbExtinctionCoeff);
else
    deltaHbO = nan;deltaHbR = nan;
end
disp('multispectral calculations done')
%% run spatial smoothing if applicable:
if imagefiltering
    switch movtype
        case 'gcamp'
            for n = 1:length(ch1_sorted)
                for n2 = 1:size(ch1_sorted{n},3)
                    for n3 = 1:size(ch1_sorted{n},4)
                        ch1_sorted{n}(:,:,n2,n3) = imgaussfilt(ch1_sorted{n}(:,:,n2,n3),10);
                        ch2_sorted{n}(:,:,n2,n3) = imgaussfilt(ch2_sorted{n}(:,:,n2,n3),2);
                    end
                end
            end
        case 'ios'
            for n = 1:length(ch1_sorted)
                for n2 = 1:size(ch1_sorted{n},3)
                    for n3 = 1:size(ch1_sorted{n},4)
                        ch1_sorted{n}(:,:,n2,n3) = imgaussfilt(ch1_sorted{n}(:,:,n2,n3),2);
                        ch2_sorted{n}(:,:,n2,n3) = imgaussfilt(ch2_sorted{n}(:,:,n2,n3),2);
                        try
                            deltaHbO{n}(:,:,n2,n3) = imgaussfilt(deltaHbO{n}(:,:,n2,n3),2);
                            deltaHbR{n}(:,:,n2,n3) = imgaussfilt(deltaHbR{n}(:,:,n2,n3),2);
                        catch
                            deltaHbO{n}(:,:,n2,n3) = 1;
                            deltaHbR{n}(:,:,n2,n3) = 1;
                        end
                    end
                end
            end
        case 'ios1ch'
            for n = 1:length(ch2_sorted)
                for n2 = 1:size(ch2_sorted{n},3)
                    for n3 = 1:size(ch2_sorted{n},4)
                        ch2_sorted{n}(:,:,n2,n3) = imgaussfilt(ch2_sorted{n}(:,:,n2,n3),2);
                    end
                end
            end
    end
end
disp('spatial filtering done')

%% get normalization quantities for each channel (largest per-stimulus
% 99th percentile, ignoring the first frame):
ch1max = 0;
ch2max = 0;
for n = 1:length(ch2_sorted)
    cur = ch1_sorted{n}(:,:,2:end,:);
    ch1max = max(ch1max, prctile(cur(:),99));
    cur = ch2_sorted{n}(:,:,2:end,:);
    ch2max = max(ch2max, prctile(cur(:),99));
end
disp('normalization amounts quantified')

%% now can do drift correction since the image processing is done:
if skipRegistration
    disp('skipping image regsitration')
else
    f = waitbar(0,'percent movie registered');
    % make sure the waitbar is closed even if registration throws
    closeWaitbar = onCleanup(@() delete(f(ishghandle(f))));
    for n = 1:length(ch1_sorted)
        for n2 = 1:size(ch1_sorted{n},3)
            for n3 = 1:size(ch1_sorted{n},4)
                curind = originalFrameIdx{n}(n2,n3);
                try
                    ch1_sorted{n}(:,:,n2,n3) = imwarp(ch1_sorted{n}(:,:,n2,n3),imtrans(curind),'OutputView',imref2d(channelsizes));
                catch
                    disp('unable to register channel 1 - hopefully you do not need it!')
                end
                ch2_sorted{n}(:,:,n2,n3) = imwarp(ch2_sorted{n}(:,:,n2,n3),imtrans(curind),'OutputView',imref2d(channelsizes));
                if isa(deltaHbO,'cell')
                    deltaHbO{n}(:,:,n2,n3) = imwarp(deltaHbO{n}(:,:,n2,n3),imtrans(curind),'OutputView',imref2d(channelsizes));
                    deltaHbR{n}(:,:,n2,n3) = imwarp(deltaHbR{n}(:,:,n2,n3),imtrans(curind),'OutputView',imref2d(channelsizes));
                end
            end
        end
        waitbar(n/length(ch1_sorted),f);
    end
    clear closeWaitbar      %runs the cleanup, closing the waitbar
end
disp('drift correction applied')


%% calculate z-scored trajectories against the per-trial baseline frames:
for n = 1:length(ch1_sorted)
    curbaseline = ch1_sorted{n}(:,:,baselineframes,:);
    curbaselinestd = repmat(std(curbaseline,0,3),1,1,framesPerStim_effective,1);
    curbaseline = repmat(mean(curbaseline,3),1,1,framesPerStim_effective,1);
    ch1_sortedZ{n} = (ch1_sorted{n}-curbaseline)./curbaselinestd;

    curbaseline = ch2_sorted{n}(:,:,baselineframes,:);
    curbaselinestd = repmat(std(curbaseline,0,3),1,1,framesPerStim_effective,1);
    curbaseline = repmat(mean(curbaseline,3),1,1,framesPerStim_effective,1);
    ch2_sortedZ{n} = (ch2_sorted{n}-curbaseline)./curbaselinestd;

    temp = repmat(windowmask,1,1,size(ch1_sorted{n},3),size(ch1_sorted{n},4));
    ch1_sortedZ{n}(~temp) = 0;
    ch2_sortedZ{n}(~temp) = 0;
end

if donormalize && normalizeSingleTrials
    for n = 1:length(ch1_sorted)
        ch1_sorted{n} = ch1_sorted{n}/ch1max;
        ch2_sorted{n} = ch2_sorted{n}/ch2max;
    end
end
channelNormalization = [ch1max;ch2max];


%% get average stim movie (median and std across repeats)
avMovie_ch1 = cell(size(ch1_sorted));
avMovie_ch2 = cell(size(ch2_sorted));
stdMovie_ch1 = cell(size(ch1_sorted));
stdMovie_ch2 = cell(size(ch2_sorted));
for n = 1:length(avMovie_ch1)
    avMovie_ch1{n} = median(ch1_sorted{n},4);
    avMovie_ch2{n} = median(ch2_sorted{n},4);
    stdMovie_ch1{n} = std(ch1_sorted{n},0,4);
    stdMovie_ch2{n} = std(ch2_sorted{n},0,4);
end
if isa(deltaHbO,'cell')
    avHbO = cell(size(deltaHbO));
    avHbR = cell(size(deltaHbR));
    for n = 1:length(avHbO)
        avHbO{n} = mean(deltaHbO{n},4);
        avHbR{n} = mean(deltaHbR{n},4);
    end
else
    avHbO = nan;
    avHbR = nan;
end

%% get average peak responses:
stimResponse_ch1 = zeros(channelsizes(1),channelsizes(2),length(avMovie_ch1));
stimResponse_ch2 = zeros(channelsizes(1),channelsizes(2),length(avMovie_ch2));
for n = 1:length(avMovie_ch1)
    switch movtype
        case 'gcamp'
            stimResponse_ch1(:,:,n) = mean(avMovie_ch1{n}(:,:,stimframesGCaMP),3);
            stimResponse_ch2(:,:,n) = mean(avMovie_ch2{n}(:,:,stimframesIOS),3);
        case 'ios'
            stimResponse_ch1(:,:,n) = mean(avMovie_ch1{n}(:,:,stimframesIOS),3);
            stimResponse_ch2(:,:,n) = mean(avMovie_ch2{n}(:,:,stimframesIOS),3);
        case 'ios1ch'
            stimResponse_ch2(:,:,n) = mean(avMovie_ch2{n}(:,:,stimframesIOS),3);
    end
end



function [deltaHbO, deltaHbR] = gethemoglobin(lambda1,lambda2,wavelengths,DPFx_cm,HbExtinctionCoeff)
% Convert paired intensity arrays into hemoglobin concentration changes.
%   lambda1, lambda2: cell arrays of intensity data for the two wavelengths;
%       element k of lambda1 is paired with element k of lambda2
%   wavelengths, DPFx_cm, HbExtinctionCoeff: forwarded unchanged to
%       iosReflectance2concentration_v1_0
%   deltaHbO, deltaHbR: cell arrays the same size as lambda1 holding, for
%       each pair, the two outputs of iosReflectance2concentration_v1_0
deltaHbO = cell(size(lambda1));
deltaHbR = deltaHbO;
nPairs = length(lambda1);
for k = 1:nPairs
    intensityPair = {lambda1{k}, lambda2{k}};
    [deltaHbO{k}, deltaHbR{k}] = iosReflectance2concentration_v1_0(intensityPair,wavelengths,DPFx_cm,HbExtinctionCoeff);
end


    